use re_entity_db::InstancePathHash;
use re_log_types::{EntityPath, Instance};
use re_sdk_types::Archetype as _;
use re_sdk_types::archetypes::{
    CoordinateFrame, InstancePoses3D, Pinhole, Transform3D, TransformAxes3D,
};
use re_sdk_types::components::{AxisLength, ShowLabels};
use re_view::latest_at_with_blueprint_resolved_data;
use re_viewer_context::{
    IdentifiedViewSystem, RequiredComponents, ViewContext, ViewContextCollection, ViewQuery,
    ViewSystemExecutionError, VisualizerExecutionOutput, VisualizerQueryInfo, VisualizerSystem,
};

use super::{SpatialViewVisualizerData, UiLabel, UiLabelStyle, UiLabelTarget};
use crate::contexts::TransformTreeContext;
use crate::view_kind::SpatialViewKind;
use crate::visualizers::utilities::transform_info_for_entity_or_report_error;

/// Visualizer that draws transform axes (one arrow per axis) for entities in spatial views.
///
/// Wraps a [`SpatialViewVisualizerData`], which collects the UI labels and bounding boxes
/// produced while executing the visualizer.
pub struct TransformAxes3DVisualizer(SpatialViewVisualizerData);

impl Default for TransformAxes3DVisualizer {
    /// Creates the visualizer with fresh spatial data that prefers the 3D view kind.
    fn default() -> Self {
        let preferred_view_kind = Some(SpatialViewKind::ThreeD);
        let data = SpatialViewVisualizerData::new(preferred_view_kind);
        Self(data)
    }
}

impl IdentifiedViewSystem for TransformAxes3DVisualizer {
    /// The identifier under which this visualizer system is known.
    fn identifier() -> re_viewer_context::ViewSystemIdentifier {
        re_viewer_context::ViewSystemIdentifier::from("TransformAxes3D")
    }
}

impl VisualizerSystem for TransformAxes3DVisualizer {
    /// Describes which entities this visualizer applies to.
    ///
    /// Starts from the query info derived from the [`TransformAxes3D`] archetype and then
    /// replaces the required components with an "any of" set, see the comment below.
    fn visualizer_query_info(&self) -> VisualizerQueryInfo {
        let mut query_info = VisualizerQueryInfo::from_archetype::<TransformAxes3D>();

        // Make this visualizer available for any entity that has at least one component of any
        // transform-related archetype: `Transform3D`, `CoordinateFrame`, `InstancePoses3D`,
        // `Pinhole` or `TransformAxes3D` itself.
        query_info.required = RequiredComponents::Any(
            Transform3D::all_component_identifiers()
                .chain(CoordinateFrame::all_component_identifiers())
                .chain(InstancePoses3D::all_component_identifiers())
                .chain(Pinhole::all_component_identifiers())
                .chain(TransformAxes3D::all_component_identifiers())
                .collect(),
        );

        query_info
    }

    /// Draws axis arrows (and optionally frame labels) for every visible data result.
    ///
    /// For each entity this:
    /// 1. looks up its transform info (errors are reported on `output` and the entity is skipped),
    /// 2. determines the transforms at which to draw axes,
    /// 3. queries `axis_length` and `show_frame` (with blueprint overrides and fallbacks),
    /// 4. adds labels, bounding boxes and line segments accordingly.
    ///
    /// # Errors
    ///
    /// Fails if the [`TransformTreeContext`] is not available or if the line draw data cannot
    /// be created.
    fn execute(
        &mut self,
        ctx: &ViewContext<'_>,
        query: &ViewQuery<'_>,
        context_systems: &ViewContextCollection,
    ) -> Result<VisualizerExecutionOutput, ViewSystemExecutionError> {
        let mut output = VisualizerExecutionOutput::default();

        let transforms = context_systems.get::<TransformTreeContext>()?;

        let latest_at_query = re_chunk_store::LatestAtQuery::new(query.timeline, query.latest_at);

        // Counting all transforms ahead of time is a bit wasteful, but we also don't expect a huge amount,
        // so let re_renderer's allocator internally decide what buffer sizes to pick & grow them as we go.
        let mut line_builder = re_renderer::LineDrawableBuilder::new(ctx.viewer_ctx.render_ctx());
        line_builder.radius_boost_in_ui_points_for_outlines(
            re_view::SIZE_BOOST_IN_POINTS_FOR_LINE_OUTLINES,
        );

        // The queried components are the same for every entity, so resolve them only once.
        let axis_length_identifier = TransformAxes3D::descriptor_axis_length().component;
        let show_frame_identifier = TransformAxes3D::descriptor_show_frame().component;

        for data_result in query.iter_visible_data_results(Self::identifier()) {
            let Some(transform_info) = transform_info_for_entity_or_report_error(
                transforms,
                &data_result.entity_path,
                &mut output,
            ) else {
                continue;
            };

            // Determine which transforms to draw axes at.
            // For pinhole cameras, we draw at the pinhole location only.
            // For normal entities, we iterate over all instance poses.
            let transforms_to_draw: smallvec::SmallVec<[glam::Affine3A; 1]> =
                if let Some(pinhole_tree_root_info) =
                    transforms.pinhole_tree_root_info(transform_info.tree_root())
                {
                    if transform_info.tree_root()
                        == re_tf::TransformFrameIdHash::from_entity_path(&data_result.entity_path)
                    {
                        // We're _at_ that pinhole.
                        // Don't apply the from-2D transform, stick with the last known 3D.
                        smallvec::smallvec![
                            pinhole_tree_root_info
                                .parent_root_from_pinhole_root
                                .as_affine3a()
                        ]
                    } else {
                        // We're inside a 2D space. But this is a 3D transform.
                        // Something is wrong here and this is not the right place to report it.
                        // Better just don't draw the axis!
                        continue;
                    }
                } else {
                    transform_info
                        .target_from_instances()
                        .iter()
                        .map(|t| t.as_affine3a())
                        .collect()
                };

            // Note, we use this interface instead of `data_result.latest_at_with_blueprint_resolved_data` to avoid querying
            // for a bunch of unused components. The actual transform data comes out of the context manager and can't be
            // overridden via blueprint anyways.
            let results = latest_at_with_blueprint_resolved_data(
                ctx,
                None,
                &latest_at_query,
                data_result,
                [axis_length_identifier, show_frame_identifier],
                false,
            );

            let axis_length: f32 = results
                .get_mono_with_fallback::<AxisLength>(axis_length_identifier)
                .into();

            if axis_length == 0.0 {
                // Don't draw axis and don't add to the bounding box!
                continue;
            }

            let show_frame: bool = results
                .get_mono_with_fallback::<ShowLabels>(show_frame_identifier)
                .into();

            if show_frame {
                // Add label at the center of each transform instance if `show_frame` is enabled.
                let frame_id_hash =
                    transforms.transform_frame_id_for(data_result.entity_path.hash());

                if let Some(frame_id) = transforms.lookup_frame_id(frame_id_hash) {
                    self.0
                        .ui_labels
                        .extend(transforms_to_draw.iter().map(|transform| UiLabel {
                            text: frame_id.to_string(),
                            style: UiLabelStyle::Default,
                            target: UiLabelTarget::Position3D(
                                transform.transform_point3(glam::Vec3::ZERO),
                            ),
                            labeled_instance: InstancePathHash::entity_all(
                                &data_result.entity_path,
                            ),
                        }));
                } else {
                    // It should not be possible to hit this path and frame id hashes are not something that
                    // we should ever expose to our users, so let's add a debug assert for good measure.
                    debug_assert!(
                        false,
                        "[DEBUG ASSERT] unable to resolve frame id hash {frame_id_hash:?}"
                    );
                    output.report_error_for(
                        data_result.entity_path.clone(),
                        format!("Could not resolve frame id hash {frame_id_hash:?}"),
                    );
                }
            }

            // The outline masks only depend on the entity, so look them up once instead of
            // (up to) twice per instance.
            let entity_outline_masks = query
                .highlights
                .entity_outline_mask(data_result.entity_path.hash());

            // Draw axes for each instance
            for (instance_index, world_from_obj) in transforms_to_draw.iter().enumerate() {
                // Only add the center to the bounding box - the lines may be dependent on the bounding box, causing a feedback loop otherwise.
                self.0.add_bounding_box(
                    data_result.entity_path.hash(),
                    macaw::BoundingBox::ZERO,
                    *world_from_obj,
                );

                // Check for per-instance highlighting, fall back to overall entity highlighting
                let outline_mask = entity_outline_masks
                    .instances
                    .get(&Instance::from(instance_index as u64))
                    .copied()
                    .unwrap_or(entity_outline_masks.overall);

                add_axis_arrows(
                    ctx.tokens(),
                    &mut line_builder,
                    *world_from_obj,
                    Some(&data_result.entity_path),
                    axis_length,
                    outline_mask,
                    instance_index as u64,
                );
            }
        }

        Ok(output.with_draw_data([line_builder.into_draw_data()?.into()]))
    }

    /// Exposes the wrapped [`SpatialViewVisualizerData`] as a type-erased reference.
    fn data(&self) -> Option<&dyn std::any::Any> {
        Some(self.0.as_any())
    }

    /// Exposes the visualizer itself as a type-erased reference.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

/// Adds three arrows (one each along the local X, Y and Z axes) to `line_builder`.
///
/// * `tokens` - provides the colors used for the X, Y and Z arrows.
/// * `line_builder` - receives a new batch containing the three line segments.
/// * `world_from_obj` - transform applied to the batch; the arrows start at its origin.
/// * `ent_path` - if set, used as the batch label and (via its hash) as the picking object id.
///   Otherwise the batch is labeled `"axis_arrows"` and the picking object id is `0`.
/// * `axis_length` - length of each arrow in the batch's object space.
/// * `outline_mask_ids` - outline mask applied to the whole batch.
/// * `instance_index` - picking instance id assigned to all three segments.
pub fn add_axis_arrows(
    tokens: &re_ui::DesignTokens,
    line_builder: &mut re_renderer::LineDrawableBuilder<'_>,
    world_from_obj: glam::Affine3A,
    ent_path: Option<&EntityPath>,
    axis_length: f32,
    outline_mask_ids: re_renderer::OutlineMaskPreference,
    instance_index: u64,
) {
    use re_renderer::renderer::LineStripFlags;

    // TODO(andreas): It would be nice if could display the ViewCoordinates axis names (left/right/up) as a tooltip on hover.

    let line_radius = re_renderer::Size::new_ui_points(1.0);

    // Build the fallback label lazily so we don't allocate it when an entity path is given.
    let batch_label = ent_path.map_or_else(|| "axis_arrows".to_owned(), |p| p.to_string());

    let mut line_batch = line_builder
        .batch(batch_label)
        .world_from_obj(world_from_obj)
        .triangle_cap_length_factor(10.0)
        .triangle_cap_width_factor(3.0)
        .outline_mask_ids(outline_mask_ids)
        .picking_object_id(re_renderer::PickingLayerObjectId(
            ent_path.map_or(0, |p| p.hash64()),
        ));
    let picking_instance_id = re_renderer::PickingLayerInstanceId(instance_index);

    // One arrow per axis, in X, Y, Z order. All arrows share radius, caps and picking id and
    // only differ in direction and color.
    let axes = [
        (glam::Vec3::X, tokens.axis_color_x),
        (glam::Vec3::Y, tokens.axis_color_y),
        (glam::Vec3::Z, tokens.axis_color_z),
    ];
    for (direction, color) in axes {
        line_batch
            .add_segment(glam::Vec3::ZERO, direction * axis_length)
            .radius(line_radius)
            .color(color)
            .flags(LineStripFlags::FLAG_CAP_END_TRIANGLE | LineStripFlags::FLAG_CAP_START_ROUND)
            .picking_instance_id(picking_instance_id);
    }
}
